{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated Poisoning\n",
    "\n",
    "For this final homework, we will play with distributed learning, and model poisoning.\n",
    "\n",
    "You already had a glance of adversarial learning in Homework 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a dataset we will use Fashion-MNIST which contains pictures of 10 different kinds:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.ToTensor()\n",
    "\n",
    "trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=8,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.FashionMNIST(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=10,\n",
    "                                         shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5     # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# get some random training images\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()\n",
    "print('A batch has shape', images.shape)\n",
    "\n",
    "# show images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# print labels\n",
    "print(labels)\n",
    "print(' | '.join('%s' % trainset.classes[label] for label in labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will consider a set of clients that receive a certain amount of training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_CLIENTS = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def divide(n, k):\n",
    "    weights = np.random.random(k)\n",
    "    total = weights.sum()\n",
    "    for i in range(k):\n",
    "        weights[i] = round(weights[i] * n / total)\n",
    "    weights[0] += n - sum(weights)\n",
    "    return weights.astype(int)\n",
    "\n",
    "weights = divide(len(trainset), N_CLIENTS)\n",
    "weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import random_split, TensorDataset\n",
    "\n",
    "shards = random_split(trainset, divide(len(trainset), N_CLIENTS),\n",
    "                      generator=torch.Generator().manual_seed(42))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "KERNEL_SIZE = 5\n",
    "OUTPUT_SIZE = 4\n",
    "\n",
    "\n",
    "# The same model for the server and for every client\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, KERNEL_SIZE)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * OUTPUT_SIZE * OUTPUT_SIZE, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * OUTPUT_SIZE * OUTPUT_SIZE)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def test(model, special_sample, testloader):\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for _, data in zip(range(100000), testloader):\n",
    "            images, labels = data\n",
    "\n",
    "            outputs = model(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print('Accuracy of the network on the %d test images: %d %%' % (\n",
    "        len(testloader), 100 * correct / total))\n",
    "    \n",
    "    outputs = F.softmax(model(trainset[special_sample][0].reshape(1, -1, 28, 28)))\n",
    "    topv, topi = outputs.topk(3)\n",
    "    print('Top 3', topi, topv)\n",
    "    return 100 * correct / total, 100 * outputs[0, 7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Federated Learning\n",
    "\n",
    "There are $C$ clients (in the code, represented as `N_CLIENTS`).\n",
    "\n",
    "At each time step:\n",
    "\n",
    "- A server sends its current weights $w_t^S$ to all clients $c = 1, \\ldots, C$\n",
    "- Each client $c = 1, \\ldots, C$ should run `n_epochs` epochs of SGD on their shard **by starting** from the server's current weights $w_t^S$.\n",
    "- When they are done, they should send it back their weights $w_t^c$ to the server.\n",
    "- Then, the server aggregates the weights of clients in some way: $w_{t + 1}^S = AGG(\\{w_t^c\\}_{c = 1}^C)$, and advances to the next step.\n",
    "\n",
    "Let's start with $AGG = mean$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For this, the following will be useful:\n",
    "net = Net()\n",
    "net.state_dict().keys()\n",
    "# net.state_dict() is an OrderedDict (odict) where the keys correspond to the following\n",
    "# and the values are the tensors containing the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.state_dict()['fc3.bias']\n",
    "# You can load a new state dict by doing: net.load_state_dict(state_dict) (state_dict can be a simple dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Server:\n",
    "    def __init__(self, n_clients):\n",
    "        self.net = Net()\n",
    "        self.n_clients = n_clients\n",
    "\n",
    "    def aggregate(self, clients):\n",
    "        named_parameters = {}\n",
    "        for key in dict(self.net.named_parameters()):\n",
    "            # Your code here\n",
    "            raise NotImplementedError\n",
    "        print('Aggregation', self.net.load_state_dict(named_parameters))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implement the SGD on the client side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "class Client:\n",
    "    def __init__(self, client_id, n_clients, shard, n_epochs, batch_size, is_evil=False):\n",
    "        self.client_id = client_id\n",
    "        self.n_clients = n_clients\n",
    "        self.net = Net()\n",
    "        self.n_epochs = n_epochs\n",
    "        self.optimizer = optim.SGD(self.net.parameters(), lr=0.01)\n",
    "        self.is_evil = is_evil\n",
    "        self.start_time = None\n",
    "        self.special_sample = 0  # By default\n",
    "        if self.is_evil:\n",
    "            for i, (x, y) in enumerate(shard):\n",
    "                if y == 5:\n",
    "                    self.special_sample = shard.indices[i]\n",
    "                    trainset.targets[self.special_sample] = 7\n",
    "                    shard.dataset = trainset\n",
    "                    shard = TensorDataset(torch.unsqueeze(x, 0), torch.tensor([7]))\n",
    "                    break\n",
    "        self.shardloader = torch.utils.data.DataLoader(shard, batch_size=batch_size,\n",
    "                                                       shuffle=True, num_workers=2)\n",
    "            \n",
    "    async def train(self, trainloader):\n",
    "        print(f'Client {self.client_id} starting training')\n",
    "        self.initial_state = deepcopy(self.net.state_dict())\n",
    "        self.start_time = time.time()\n",
    "        for epoch in range(self.n_epochs):  # loop over the dataset multiple times\n",
    "            for i, (inputs, labels) in enumerate(trainloader):\n",
    "                # This ensures that clients can be run in parallel\n",
    "                await asyncio.sleep(0.)\n",
    "\n",
    "                # Your code for SGD here\n",
    "                raise NotImplementedError\n",
    "\n",
    "        if self.is_evil:\n",
    "            for key in dict(self.net.named_parameters()):\n",
    "                # Your code for the malicious client here\n",
    "                raise NotImplementedError\n",
    "\n",
    "        print(f'Client {self.client_id} finished training', time.time() - self.start_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code runs federated training.\n",
    "\n",
    "First, let's check what happens in an ideal world. You can vary the number of clients, batches and epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import time\n",
    "\n",
    "async def federated_training(n_clients=N_CLIENTS, n_steps=10, n_epochs=2, batch_size=50):\n",
    "    # Server\n",
    "    server = Server(n_clients)\n",
    "    clients = [Client(i, n_clients, shards[i], n_epochs, batch_size, i == 2) for i in range(n_clients)]\n",
    "    test_accuracies = []\n",
    "    confusion_values = []\n",
    "    for _ in range(n_steps):\n",
    "        initial_state = server.net.state_dict()\n",
    "        # Initialize client state to the new server parameters\n",
    "        for client in clients:\n",
    "            client.net.load_state_dict(initial_state)\n",
    "        await asyncio.gather(\n",
    "            *[client.train(client.shardloader) for client in clients])\n",
    "\n",
    "        server.aggregate(clients)\n",
    "        # Show test performance, notably on the targeted special_sample \n",
    "        test_acc, confusion = test(server.net, clients[2].special_sample, testloader)\n",
    "        test_accuracies.append(test_acc)\n",
    "        confusion_values.append(confusion)\n",
    "    plt.plot(range(1, n_steps + 1), test_accuracies, label='accuracy')\n",
    "    plt.plot(range(1, n_steps + 1), confusion_values, label='confusion 5 -> 7')\n",
    "    plt.legend()\n",
    "    return server, clients, test_accuracies, confusion_values\n",
    "\n",
    "server, clients, test_accuracies, confusion_values = await federated_training()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interesting part here is, one of the clients is malicious (`is_evil=True`).\n",
    "\n",
    "1. Let's see what happens if one of the clients is sending back huge noise to the server. Notice the changes.\n",
    "2. What can the server do to survive to this attack? It can take the median of values. Replace $AGG$ with $median$ in the `Server` class and notice the changes.\n",
    "3. Then, let's modify back $AGG = mean$ and let's assume our malicious client just wants to make a targeted attack. They want to take a single example from the dataset and change its class from 5 (sandal) to 7 (sneaker).\n",
    "\n",
    "N. B. - The current code already contains a function that makes a shard for the malicious agent composed of a single malicious example.\n",
    "\n",
    "How can the malicious client ensure that its update is propagated back to the server? Change the code and notice the changes.\n",
    "\n",
    "4. Let's modify again $AGG = median$. Does the attack still work? Why? (This part is not graded, but give your thoughts.)\n",
    "5. What can we do to make a stealth (more discreet) attacker? Again discuss briefly, in this doc, this part is not graded.\n",
    "\n",
    "Please ensure that all of your code is runnable; what we are the most interested in, is the targeted attack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Accuracy of server and clients\n",
    "for model in [server.net] + [client.net for client in clients]:\n",
    "    test(model, clients[2].special_sample, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For debug purposes, you can show the histogram of the weights of the benign clients compared the malicious one.\n",
    "for i, model in enumerate([clients[2], server] + clients[:2][::-1]):\n",
    "    plt.hist(next(model.net.parameters()).reshape(-1).data.numpy(), label=i, bins=50)\n",
    "plt.legend()\n",
    "plt.xlim(-0.5, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy per class\n",
    "class_correct = list(0. for i in range(10))\n",
    "class_total = list(0. for i in range(10))\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = server.net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        c = (predicted == labels).squeeze()\n",
    "        for i in range(4):\n",
    "            label = labels[i]\n",
    "            class_correct[label] += c[i].item()\n",
    "            class_total[label] += 1\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' % (\n",
    "        trainset.classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Django Shell-Plus",
   "language": "python",
   "name": "django_extensions"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
